/******************************************************************************
 * Copyright (c) 2020 Advanced Micro Devices, Inc. All rights reserved.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 *****************************************************************************/
#include "random_access_tester.hpp"

#include <roc_shmem.hpp>

/******************************************************************************
 * DEVICE TEST KERNEL
 *****************************************************************************/

__device__ bool
thread_passing(int num_bins,
               uint32_t *bin_threads,
               uint32_t * off_bins,
               uint32_t *PE_bins,
               int *offset,
               int *PE,
               int coal_coef,
               int size)
{
    bool pass = false;
    int wave_id = ((hipThreadIdx_x + hipBlockIdx_x * hipBlockDim_x) /64) ;//get_global_wave_id();

    int off = wave_id * num_bins;
    for(int i =0; i< num_bins; i++){
        if(((hipThreadIdx_x%64) >= bin_threads[i+off]) &&
           ((hipThreadIdx_x%64) < bin_threads[i+off]+ coal_coef)){
                pass = true;
                *offset = off_bins[i+off] + (((hipThreadIdx_x%64) -
                                            bin_threads[i+off])* size) ;
                *PE = PE_bins[i+off];

        }
    }
    return pass;
}

__global__ void
RandomAccessTest(int loop,
               int skip,
               uint64_t *timer,
               int *s_buf,
               int *r_buf,
               int size,
               OpType type,
               int coal_coef,
               int num_bins,
               int num_waves,
               uint32_t * threads_bins,
               uint32_t * off_bins,
               uint32_t *PE_bins,
               ShmemContextType ctx_type)
{
    uint64_t start;
    __shared__ roc_shmem_ctx_t ctx;
    roc_shmem_wg_init();
    roc_shmem_wg_ctx_create(ctx_type, &ctx);

    int pe = roc_shmem_my_pe(ctx);
    int offset;
    int PE;

    if(thread_passing(num_bins, threads_bins, off_bins, PE_bins, &offset,
                      &PE, coal_coef, (size/sizeof(int)))==true){

        s_buf = s_buf + offset;
        r_buf = r_buf + offset;

        for(int i = 0; i < loop+skip; i++) {
            if(i == skip)
                start = roc_shmem_timer();
            switch (type) {
                case GetType:
                    roc_shmem_getmem(ctx, r_buf, s_buf, size, PE);
                    break;
                case PutType:
                    roc_shmem_putmem(ctx, (char*)r_buf, (char*)s_buf, size, PE);
                    break;
                default:
                    break;
            }
        }

        roc_shmem_quiet(ctx);

        atomicAdd((unsigned long long *) &timer[hipBlockIdx_x],
                    roc_shmem_timer() - start);
    }
    roc_shmem_wg_ctx_destroy(ctx);
    roc_shmem_wg_finalize();

}



/******************************************************************************
 * HOST HELPER FUNCTIONS
 ****************************************************************************/
__host__ void init_bins(int num_bins, int num_waves, uint32_t* off_bins,
                        uint32_t *threads_bins, uint32_t * PE_bins, int size,
                        int coal_coef, int num_pes, int max_size)
{
    srand(time(NULL));

    for (int j = 0; j < num_waves; j++) {
        for (int i = 0; i < num_bins; i++) {
            int current_bin = j * num_bins + i;
            assert((64 % num_bins) == 0);
            int quad_size = 64 / num_bins;
            int allowed_index_range = quad_size - coal_coef - 1;
            assert(allowed_index_range >= 0);
            int rand_val = 0;

            if (allowed_index_range)
                rand_val = rand() % allowed_index_range;

            threads_bins[current_bin] = rand_val + i * quad_size;

            quad_size = max_size / (num_bins + num_waves);
            rand_val = rand() % (quad_size - (size * coal_coef - 1));
            off_bins[current_bin] = rand_val + current_bin * quad_size;

            PE_bins[current_bin] = rand() % num_pes;
        }
    }
}

/******************************************************************************
 * HOST TESTER CLASS METHODS
 *****************************************************************************/
RandomAccessTester::RandomAccessTester(TesterArguments args)
    : Tester(args)
{
    int max_size = args.max_msg_size;
    int wg_size = args.wg_size;
    _num_waves = (args.wg_size/64) * args.num_wgs;
    _num_bins = args.thread_access / args.coal_coef;
    assert((args.wg_size/64)<=1);

    s_buf = (int *)roc_shmem_malloc( max_size *  wg_size * space);
    r_buf = (int *)roc_shmem_malloc( max_size *  wg_size * space);
    h_buf = (int *)malloc(max_size * wg_size * space);
    h_dev_buf = (int *)malloc(max_size * wg_size * space );
    hipMalloc((void**)&_threads_bins, sizeof(uint32_t)* _num_waves * _num_bins);
    hipMalloc((void**)&_off_bins, sizeof(uint32_t)*_num_waves * _num_bins);
    hipMalloc((void**)&_PE_bins, sizeof(uint32_t)*_num_waves * _num_bins);
    memset(_threads_bins, 0,  sizeof(uint32_t)* _num_waves * _num_bins);
    memset(_off_bins, 0,  sizeof(uint32_t)* _num_waves * _num_bins);
    memset(_PE_bins, 0,  sizeof(uint32_t)* _num_waves * _num_bins);



}

RandomAccessTester::~RandomAccessTester()
{
    roc_shmem_free(s_buf);
    roc_shmem_free(r_buf);
    free(h_buf);
    free(h_dev_buf);
    hipFree(_threads_bins);
    hipFree(_off_bins);
    hipFree(_PE_bins);
}


void
RandomAccessTester::resetBuffers()
{
    for(int i=0; i < args.max_msg_size/sizeof(int) * args.wg_size * space; i++){
        s_buf[i] =1;
        r_buf[i] =0;
        h_buf[i] =0;
    }
}

void
RandomAccessTester::launchKernel(dim3 gridSize,
                                 dim3 blockSize,
                                 int loop,
                                 uint64_t size)
{
    size_t shared_bytes;
    roc_shmem_dynamic_shared(&shared_bytes);
    int _thread_access = args.thread_access;
    int _coal_coef = args.coal_coef;

    assert(_coal_coef >= 1);
    assert(gridSize.x == 1 && gridSize.y == 1 && gridSize.z == 1);

    init_bins(_num_bins, _num_waves, _off_bins, _threads_bins,
              _PE_bins, size/sizeof(int), _coal_coef, args.numprocs,
              (space * size * args.wg_size )  /sizeof(int));

    if(args.myid ==0){
        hipLaunchKernelGGL(RandomAccessTest,
                       gridSize,
                       blockSize,
                       shared_bytes,
                       stream,
                       loop,
                       args.skip,
                       timer,
                       s_buf,
                       r_buf,
                       size,
                       (OpType) args.op_type,
                       _coal_coef,
                       _num_bins,
                       _num_waves,
                       _threads_bins,
                       _off_bins,
                       _PE_bins,
                       _shmem_context);
    }
    num_msgs = (loop + args.skip) * _num_waves * _thread_access;
    num_timed_msgs = loop *  _num_waves * _thread_access;
}

void
RandomAccessTester::verifyResults(uint64_t size)
{
    int offset, i, j;
    for (int k = 0; k < _num_waves; k++) {
        for(i =0; i< _num_bins; i++){
            int index = i + _num_bins * k;
            if(args.op_type == PutType){
                if(_PE_bins[index] == args.myid){
                    offset = _off_bins[index];
                    for(j=0; j<((size/sizeof(int))*args.coal_coef); j++){
                        h_buf[offset+j] = 1;
                    }
                }
            }else{
                if(args.myid ==0){
                offset = _off_bins[index];
                    for(j=0; j<((size/sizeof(int))*args.coal_coef); j++){
                        h_buf[offset+j] = 1;
                    }
                }
            }
        }
    }

    hipMemcpy(h_dev_buf, r_buf, space * args.wg_size * size, hipMemcpyDeviceToHost);
    hipDeviceSynchronize();
    for(i=0;i<(space * args.wg_size * size/sizeof(int)); i++){
    if(h_dev_buf[i]!=h_buf[i]){
            printf("PE %d  Got Data Validation: expecting %d got %d at  %d \n",
                    args.myid, h_buf[i], h_dev_buf[i],  i);
            exit(-1);
        }
    }
}

